import torch
import torch.nn as nn

'''
Define Attacks
'''

def fgsm(model, X, y, epsilon):
    """Construct FGSM adversarial perturbations for the examples X.

    Takes one signed-gradient step of size ``epsilon`` on the
    BCE-with-logits loss between column 0 of ``model(X + delta)`` and ``y``,
    with the gradient evaluated at ``delta = 0``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Magnitude of the step.

    Returns:
        A tensor shaped like ``X`` equal to ``epsilon * sign(grad)``. It is
        not attached to any autograd graph.
    """
    delta = torch.zeros_like(X, requires_grad=True)
    loss = nn.BCEWithLogitsLoss()(model(X + delta)[:, 0], y.float())
    # Differentiate w.r.t. delta only. Unlike loss.backward(), this does not
    # accumulate the attack gradient into the model parameters' .grad fields.
    (grad,) = torch.autograd.grad(loss, delta)
    return epsilon * grad.sign()

def rfgsm(model, X, y, epsilon):
    """Construct R-FGSM adversarial perturbations for the examples X.

    Same single signed-gradient step as FGSM, but the objective is
    ``-BCEWithLogits(-logits, y)``, where ``logits`` is column 0 of
    ``model(X + delta)``. This is algebraically equal to
    ``-BCEWithLogits(logits, 1 - y)``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Magnitude of the step.

    Returns:
        A tensor shaped like ``X`` equal to ``epsilon * sign(grad)``, with the
        gradient evaluated at ``delta = 0``. It is not attached to any
        autograd graph.
    """
    delta = torch.zeros_like(X, requires_grad=True)
    loss = -nn.BCEWithLogitsLoss()(-model(X + delta)[:, 0], y.float())
    # Differentiate w.r.t. delta only. Unlike loss.backward(), this does not
    # accumulate the attack gradient into the model parameters' .grad fields.
    (grad,) = torch.autograd.grad(loss, delta)
    return epsilon * grad.sign()


def norms(Z):
    """Compute per-example L2 norms over all but the first dimension.

    Args:
        Z: Tensor whose first dimension indexes examples. Any number of
            dimensions (>= 1) and any memory layout are accepted.

    Returns:
        A tensor of shape ``(Z.shape[0], 1, ..., 1)`` with as many dimensions
        as ``Z``, so it broadcasts against ``Z`` (for 4-D input this is
        ``(N, 1, 1, 1)``).
    """
    batch = Z.shape[0]
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    per_example = Z.reshape(batch, -1).norm(dim=1)
    # One singleton axis per non-batch dimension instead of a fixed three.
    return per_example.reshape(batch, *([1] * (Z.dim() - 1)))

def pgd(model, X, y, epsilon, alpha=100, num_iter=60):
    """Construct PGD-Inf adversarial perturbations for the examples X.

    Runs ``num_iter`` gradient-ascent steps on the BCE-with-logits loss
    between column 0 of ``model(X + delta)`` and ``y``. Each step adds the
    raw (unsigned) gradient scaled by ``X.shape[0] * alpha`` and then clamps
    every entry of ``delta`` to ``[-epsilon, epsilon]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Bound on the absolute value of every perturbation entry.
        alpha: Initial step size; divided by 5 after every 10th iteration.
        num_iter: Number of ascent steps.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    # The loss uses the default mean reduction; multiplying the step by the
    # batch size offsets that 1/N factor in the gradient.
    batch_size = X.shape[0]
    delta = torch.zeros_like(X, requires_grad=True)
    for t in range(num_iter):
        loss = criterion(model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(batch_size * alpha * grad).clamp_(-epsilon, epsilon)
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()

def rpgd(model, X, y, epsilon, alpha=100, num_iter=60):
    """Construct R-PGD-Inf adversarial perturbations for the examples X.

    Same iteration as the PGD-Inf attack, but the objective is
    ``-BCEWithLogits(-logits, y)``, where ``logits`` is column 0 of
    ``model(X + delta)``. This is algebraically equal to
    ``-BCEWithLogits(logits, 1 - y)``. Each step adds the raw gradient scaled
    by ``X.shape[0] * alpha`` and clamps ``delta`` to
    ``[-epsilon, epsilon]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Bound on the absolute value of every perturbation entry.
        alpha: Initial step size; divided by 5 after every 10th iteration.
        num_iter: Number of ascent steps.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    # The loss uses the default mean reduction; multiplying the step by the
    # batch size offsets that 1/N factor in the gradient.
    batch_size = X.shape[0]
    delta = torch.zeros_like(X, requires_grad=True)
    for t in range(num_iter):
        loss = -criterion(-model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(batch_size * alpha * grad).clamp_(-epsilon, epsilon)
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()


def pgd2(model, X, y, epsilon, alpha=100, num_iter=60):
    """Construct PGD-2 adversarial perturbations for the examples X.

    Runs ``num_iter`` gradient-ascent steps on the BCE-with-logits loss
    between column 0 of ``model(X + delta)`` and ``y``. After each raw
    gradient step (scaled by ``alpha``) the perturbation is:

    1. rescaled per example so its norm, as computed by ``norms``, does not
       exceed ``20 * epsilon``;
    2. clipped elementwise so that ``X + delta`` stays within ``[0, 1]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Norm budget parameter; the projection radius is
            ``20 * epsilon``.
        alpha: Initial step size; divided by 5 after every 10th iteration.
        num_iter: Number of ascent steps.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    radius = epsilon * 20  # the factor 20 is a fixed constant of this attack
    lower, upper = -X, 1 - X  # elementwise bounds keeping X + delta in [0, 1]
    delta = torch.zeros_like(X, requires_grad=True)
    for t in range(num_iter):
        loss = criterion(model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(alpha * grad)
            # Scale factor is 1 inside the ball and radius / norm outside it.
            delta.mul_(radius / norms(delta).clamp(min=radius))
            delta.copy_(torch.min(torch.max(delta, lower), upper))
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()

def rpgd2(model, X, y, epsilon, alpha=100, num_iter=60):
    """Construct R-PGD-2 adversarial perturbations for the examples X.

    Same iteration as the PGD-2 attack, but the objective is
    ``-BCEWithLogits(-logits, y)``, where ``logits`` is column 0 of
    ``model(X + delta)``. This is algebraically equal to
    ``-BCEWithLogits(logits, 1 - y)``. After each raw gradient step (scaled
    by ``alpha``) the perturbation is:

    1. rescaled per example so its norm, as computed by ``norms``, does not
       exceed ``20 * epsilon``;
    2. clipped elementwise so that ``X + delta`` stays within ``[0, 1]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Norm budget parameter; the projection radius is
            ``20 * epsilon``.
        alpha: Initial step size; divided by 5 after every 10th iteration.
        num_iter: Number of ascent steps.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    radius = epsilon * 20  # the factor 20 is a fixed constant of this attack
    lower, upper = -X, 1 - X  # elementwise bounds keeping X + delta in [0, 1]
    delta = torch.zeros_like(X, requires_grad=True)
    for t in range(num_iter):
        loss = -criterion(-model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(alpha * grad)
            # Scale factor is 1 inside the ball and radius / norm outside it.
            delta.mul_(radius / norms(delta).clamp(min=radius))
            delta.copy_(torch.min(torch.max(delta, lower), upper))
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()

def pgd_linf(model, X, y, epsilon, alpha=100, num_iter=60, randomize=False):
    """Construct BIM adversarial perturbations for the examples X.

    Runs ``num_iter`` signed-gradient ascent steps on the BCE-with-logits
    loss between column 0 of ``model(X + delta)`` and ``y``. Each step adds
    ``alpha * sign(grad)`` and clamps every entry of ``delta`` to
    ``[-epsilon, epsilon]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Bound on the absolute value of every perturbation entry.
        alpha: Initial step size; divided by 5 after every 10th iteration.
            While ``alpha > 2 * epsilon`` a step saturates to ``+/-epsilon``
            wherever the gradient is non-zero.
        num_iter: Number of ascent steps.
        randomize: If true, start from a perturbation drawn uniformly from
            ``[-epsilon, epsilon)`` instead of from zero.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    if randomize:
        # rand_like samples from [0, 1); rescale to [-epsilon, epsilon).
        delta = (torch.rand_like(X) * 2 * epsilon - epsilon).requires_grad_()
    else:
        delta = torch.zeros_like(X, requires_grad=True)

    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    for t in range(num_iter):
        loss = criterion(model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(alpha * grad.sign()).clamp_(-epsilon, epsilon)
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()

def rpgd_linf(model, X, y, epsilon, alpha=100, num_iter=60, randomize=False):
    """Construct R-BIM adversarial perturbations for the examples X.

    Same iteration as the BIM attack, but the objective is
    ``-BCEWithLogits(-logits, y)``, where ``logits`` is column 0 of
    ``model(X + delta)``. This is algebraically equal to
    ``-BCEWithLogits(logits, 1 - y)``. Each step adds ``alpha * sign(grad)``
    and clamps every entry of ``delta`` to ``[-epsilon, epsilon]``.

    Args:
        model: Callable applied to ``X + delta``; its output is indexed as
            ``[:, 0]`` to obtain the logits.
        X: Input batch (floating-point tensor).
        y: Targets for the loss; cast to float before use.
        epsilon: Bound on the absolute value of every perturbation entry.
        alpha: Initial step size; divided by 5 after every 10th iteration.
            While ``alpha > 2 * epsilon`` a step saturates to ``+/-epsilon``
            wherever the gradient is non-zero.
        num_iter: Number of ascent steps.
        randomize: If true, start from a perturbation drawn uniformly from
            ``[-epsilon, epsilon)`` instead of from zero.

    Returns:
        The final perturbation, shaped like ``X`` and detached from autograd.
    """
    if randomize:
        # rand_like samples from [0, 1); rescale to [-epsilon, epsilon).
        delta = (torch.rand_like(X) * 2 * epsilon - epsilon).requires_grad_()
    else:
        delta = torch.zeros_like(X, requires_grad=True)

    criterion = nn.BCEWithLogitsLoss()
    target = y.float()
    for t in range(num_iter):
        loss = -criterion(-model(X + delta)[:, 0], target)
        # Gradient w.r.t. delta only: the model parameters' .grad fields are
        # left untouched (loss.backward() would accumulate into them).
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(alpha * grad.sign()).clamp_(-epsilon, epsilon)
        if t % 10 == 9:
            alpha /= 5
    return delta.detach()
